from abc import ABCMeta
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Extra, Field, field_validator


class ChatCompletionRole(str, Enum):
    # Who authored a chat message. The `str` mix-in makes every member compare
    # equal to its plain-string value, e.g. ChatCompletionRole.user == "user".
    system = "system"  # system message
    assistant = "assistant"  # message produced by the assistant
    user = "user"  # message from the user
    function = "function"  # result of a function call, sent back to the model


class ChatCompletionFunctionCall(BaseModel):
    # A single function invocation: a call id, the arguments keyed by name,
    # and the name of the function to invoke. All three fields are required.
    id: str = Field(
        default=...,
        description="The id of the function call.",
        examples=["call_abc123"],
    )
    arguments: Dict[str, Any] = Field(
        default=...,
        description="The arguments of the function call.",
        examples=[{"a": 1, "b": 2}],
    )
    name: str = Field(
        default=...,
        description="The name of the function.",
        examples=["plus_a_and_b"],
    )


class ChatCompletionFunction(BaseModel):
    # Declaration of a function the model may call: its name, a human-readable
    # description, and its parameters as a JSON Schema object.
    name: str = Field(
        default=...,
        description="The name of the function.",
        examples=["plus_a_and_b"],
    )
    description: str = Field(
        default=...,
        description="The description of the function.",
        examples=["Add two numbers"],
    )
    # Kept as an untyped dict: the JSON Schema document is passed through as-is.
    parameters: Dict = Field(
        default=...,
        examples=[
            {
                "type": "object",
                "properties": {
                    "a": {"type": "number", "description": "The first number"},
                    "b": {"type": "number", "description": "The second number"},
                },
                "required": ["a", "b"],
            }
        ],
        description="The function's parameters are represented as an object in JSON Schema format.",
    )


class ChatCompletionMessage(BaseModel, metaclass=ABCMeta):
    # Shared base for every chat message type; subclasses add the `role` field
    # and any role-specific fields. `content` may be omitted or null.
    # NOTE: pydantic's model metaclass already derives from ABCMeta, and no
    # abstract methods are declared here, so this class is still instantiable;
    # the metaclass argument mainly documents intent.
    content: Optional[str] = Field(
        default=None,
        examples=["Hello!"],
        description="The content of the message.",
    )


class ChatCompletionSystemMessage(ChatCompletionMessage):
    # A chat message whose role is fixed to "system".
    role: ChatCompletionRole = Field(
        ...,
        description="The role of the message, which is always 'system' for a system message",
        examples=["system"],
    )

    @field_validator("role")
    @classmethod
    def validate_role(cls, role: ChatCompletionRole) -> ChatCompletionRole:
        """Reject any role other than ``system``.

        pydantic's ``Field`` has no ``Literal=`` option, so the fixed role has to
        be enforced explicitly.

        Raises:
            ValueError: if ``role`` is not ``ChatCompletionRole.system``.
        """
        if role != ChatCompletionRole.system:
            raise ValueError(f"role must be 'system' for a system message, got {role.value!r}")
        return role


class ChatCompletionUserMessage(ChatCompletionMessage):
    # A chat message whose role is fixed to "user".
    role: ChatCompletionRole = Field(
        ...,
        description="The role of the message, which is always 'user' for a user message",
        examples=["user"],
    )

    @field_validator("role")
    @classmethod
    def validate_role(cls, role: ChatCompletionRole) -> ChatCompletionRole:
        """Reject any role other than ``user``.

        pydantic's ``Field`` has no ``Literal=`` option, so the fixed role has to
        be enforced explicitly.

        Raises:
            ValueError: if ``role`` is not ``ChatCompletionRole.user``.
        """
        if role != ChatCompletionRole.user:
            raise ValueError(f"role must be 'user' for a user message, got {role.value!r}")
        return role


class ChatCompletionAssistantMessage(ChatCompletionMessage):
    # A chat message whose role is fixed to "assistant". Besides (or instead of)
    # `content`, it may carry the function calls the assistant requested.
    role: ChatCompletionRole = Field(
        ...,
        description="The role of the message, which is always 'assistant' for an assistant message",
        examples=["assistant"],
    )
    function_calls: Optional[List[ChatCompletionFunctionCall]] = Field(
        None, description="The function calls requested by the assistant."
    )

    @field_validator("role")
    @classmethod
    def validate_role(cls, role: ChatCompletionRole) -> ChatCompletionRole:
        """Reject any role other than ``assistant``.

        pydantic's ``Field`` has no ``Literal=`` option, so the fixed role has to
        be enforced explicitly.

        Raises:
            ValueError: if ``role`` is not ``ChatCompletionRole.assistant``.
        """
        if role != ChatCompletionRole.assistant:
            raise ValueError(f"role must be 'assistant' for an assistant message, got {role.value!r}")
        return role


class ChatCompletionFunctionMessage(ChatCompletionMessage):
    # A chat message whose role is fixed to "function"; `id` ties it to the
    # call id of the function call it answers.
    role: ChatCompletionRole = Field(
        ...,
        description="The role of the message, which is always 'function' for a function message",
        examples=["function"],
    )
    id: str = Field(
        ..., description="The corresponding id of the tool requested by the assistant.", examples=["call_abc123"]
    )

    @field_validator("role")
    @classmethod
    def validate_role(cls, role: ChatCompletionRole) -> ChatCompletionRole:
        """Reject any role other than ``function``.

        pydantic's ``Field`` has no ``Literal=`` option, so the fixed role has to
        be enforced explicitly.

        Raises:
            ValueError: if ``role`` is not ``ChatCompletionRole.function``.
        """
        if role != ChatCompletionRole.function:
            raise ValueError(f"role must be 'function' for a function message, got {role.value!r}")
        return role


class ChatCompletionFinishReason(str, Enum):
    # Why a generation ended. As a `str` mix-in enum, each member compares
    # equal to its plain-string value (e.g. ChatCompletionFinishReason.stop == "stop").
    stop = "stop"
    length = "length"
    function_calls = "function_calls"
    recitation = "recitation"
    error = "error"
    unknown = "unknown"  # catch-all for reasons not covered above


class ChatCompletion(BaseModel):
    # A finished chat completion: the assistant's message, why generation
    # stopped, and the creation time in milliseconds.
    object: str = Field(
        "ChatCompletion",
        description="The object type, which is always 'ChatCompletion'.",
        examples=["ChatCompletion"],
    )
    finish_reason: ChatCompletionFinishReason = Field(
        ..., description="The reason why the generation is finished.", examples=["stop"]
    )
    message: ChatCompletionAssistantMessage = Field(..., description="The message generated by the assistant.")
    created_timestamp: int = Field(
        ..., description="The timestamp in milliseconds when the response is created.", examples=[1700000000000]
    )

    @field_validator("object")
    @classmethod
    def validate_object(cls, value: str) -> str:
        """Reject any ``object`` value other than ``"ChatCompletion"``.

        pydantic's ``Field`` has no ``Literal=`` option, so the documented
        constant has to be enforced explicitly. The default is not re-validated.

        Raises:
            ValueError: if ``value`` is not ``"ChatCompletion"``.
        """
        if value != "ChatCompletion":
            raise ValueError(f"object must be 'ChatCompletion', got {value!r}")
        return value


# Chat Completion
# POST /v1/chat_completion
class ChatCompletionRequest(BaseModel):
    # Request body for a chat completion call. Unknown top-level keys are rejected.
    model_config = ConfigDict(extra="forbid")

    model_id: str = Field(
        ..., min_length=8, max_length=8, description="The chat completion model id.", examples=["abcdefgh"]
    )
    configs: Optional[Dict] = Field(None, description="The model configuration.", examples=[{"temperature": 0.5}])
    stream: bool = Field(
        False,
        description="Indicates whether the response should be streamed. "
        "If set to True, the response will be streamed using Server-Sent Events (SSE).",
        examples=[False],
    )
    messages: List[
        Union[
            ChatCompletionFunctionMessage,
            ChatCompletionAssistantMessage,
            ChatCompletionUserMessage,
            ChatCompletionSystemMessage,
        ]
    ] = Field(
        ...,
        description="The messages to be sent to the model.",
    )
    # NOTE(review): the description mentions a {'name': ...} object form, but the
    # field is typed as a plain string, so only string values validate — confirm intent.
    function_call: Optional[str] = Field(
        None,
        description="Controls whether a specific function is invoked by the model. "
        "If set to 'none', the model will generate a message without calling a function. "
        "If set to 'auto', the model can choose between generating a message or calling a function. "
        "Defining a specific function using {'name': 'my_function'} instructs the model to call that particular function. "
        "By default, 'none' is selected when there are no chat_completion_functions available, "
        "and 'auto' is selected when one or more chat_completion_functions are present.",
    )
    functions: Optional[List[ChatCompletionFunction]] = Field(None)

    @field_validator("messages", mode="before")
    @classmethod
    def validate_message(cls, messages: Any) -> Any:
        """Parse each raw message into the model class matching its ``role``.

        Dispatching on ``role`` means each message becomes the class for that
        role, not whichever Union member happens to accept it first.

        Inputs that are not an iterable of messages (``None``, a string, a dict,
        ...) are returned unchanged. pydantic's own ``List`` validation then
        reports a ``ValidationError`` instead of this hook raising a raw
        ``TypeError``/``AttributeError``.
        """
        if isinstance(messages, (str, bytes, dict)) or not hasattr(messages, "__iter__"):
            return messages
        return [cls._convert_message(m) for m in messages]

    @staticmethod
    def _convert_message(message_data: Dict) -> ChatCompletionMessage:
        """Build the message model that matches ``message_data["role"]``.

        Message models that are already constructed are returned unchanged.

        Raises:
            ValueError: if ``message_data`` is not a mapping, or if its role is
                missing or not one of system/user/assistant/function. pydantic
                surfaces this as a validation error on ``messages``.
        """
        from collections.abc import Mapping

        if isinstance(message_data, ChatCompletionMessage):
            return message_data
        if not isinstance(message_data, Mapping):
            raise ValueError(f"Invalid message: expected an object, got {type(message_data).__name__}")

        message_classes = {
            "system": ChatCompletionSystemMessage,
            "user": ChatCompletionUserMessage,
            "assistant": ChatCompletionAssistantMessage,
            "function": ChatCompletionFunctionMessage,
        }
        role = message_data.get("role")
        # The isinstance guard keeps unhashable roles (e.g. a list) from raising TypeError in the lookup.
        message_class = message_classes.get(role) if isinstance(role, str) else None
        if message_class is None:
            raise ValueError(f"Invalid message role: {role}")
        return message_class(**message_data)
